{% extends 'base.attached.class' %}

{% block content %}
var {{ class_name }} = function(priors, sigmas, thetas) {

    this.priors = priors;
    this.sigmas = sigmas;
    this.thetas = thetas;

    var _findMax = function(nums) {
        var i = 0, l = nums.length, idx = 0;
        for (; i < l; i++) {
            idx = nums[i] > nums[idx] ? i : idx;
        }
        return idx;
    };

    var _logSumExp = function(nums) {
        var i, l = nums.length;
        var max = Math.max(...nums);
        var sum = 0.;
        for (i = 0; i < l; i++) {
            sum += Math.exp(nums[i] - max);
        }
        return max + Math.log(sum);
    };

    var _compute = function(features, priors, sigmas, thetas) {
        var i, il, j, jl, sum, nij,
            likelihoods = new Array(sigmas.length);
        for (i = 0, il = sigmas.length; i < il; i++) {
            sum = 0.;
            for (j = 0, jl = sigmas[0].length; j < jl; j++) {
                sum += Math.log(2. * Math.PI * sigmas[i][j]);
            }
            nij = -0.5 * sum;
            sum = 0.;
            for (j = 0, jl = sigmas[0].length; j < jl; j++) {
                sum += Math.pow(features[j] - thetas[i][j], 2.) / sigmas[i][j];
            }
            nij -= 0.5 * sum;
            likelihoods[i] = Math.log(priors[i]) + nij;
        }
        return likelihoods;
    };

    this.predict = function(features) {
        return _findMax(_compute(features, this.priors, this.sigmas, this.thetas));
    };

    this.predictProba = function(features) {
        var jll = _compute(features, this.priors, this.sigmas, this.thetas);
        var sum = _logSumExp(jll);
        for (i = 0; i < jll.length; i++) {
            jll[i] = Math.exp(jll[i] - sum);
        }
        return jll;
    };

};

var main = function () {
    if (typeof process !== 'undefined' && typeof process.argv !== 'undefined') {
        if (process.argv.length - 2 !== {{ n_features }}) {
            var IllegalArgumentException = function(message) {
                this.message = message;
                this.name = "IllegalArgumentException";
            }
            throw new IllegalArgumentException("You have to pass {{ n_features }} features.");
        }
    }

    // Features:
    var features = process.argv.slice(2);
    for (var i = 0; i < features.length; i++) {
        features[i] = parseFloat(features[i]);
    }

    // Model data:
    {{ priors }}
    {{ sigmas }}
    {{ thetas }}

    // Estimator:
    var clf = new {{ class_name }}(priors, sigmas, thetas);

    {% if is_test or to_json %}
    // Get JSON:
    console.log(JSON.stringify({
        "predict": clf.predict(features),
        "predict_proba": clf.predictProba(features)
    }));
    {% else %}
    // Get class prediction:
    var prediction = clf.predict(features);
    console.log("Predicted class: #" + prediction);

    // Get class probabilities:
    var probabilities = clf.predictProba(features);
    for (var i = 0; i < probabilities.length; i++) {
        console.log("Probability of class #" + i + " : " + probabilities[i]);
    }
    {% endif %}
}

// Run the command-line entry point only when this file is executed directly.
// The typeof checks keep the file loadable where the CommonJS globals
// `require` and `module` do not exist (e.g. a browser script or an ES module),
// instead of failing with a ReferenceError.
if (typeof require !== 'undefined' && typeof module !== 'undefined' && require.main === module) {
    main();
}
{% endblock %}